﻿// Copyright (c) 2010 Oracle and its affiliates.
//
// MySQL Connector/NET is licensed under the terms of the GPLv2
// <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most 
// MySQL Connectors. There are special exceptions to the terms and 
// conditions of the GPLv2 as it is applied to this software, see the 
// FLOSS License Exception
// <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
//
// This program is free software; you can redistribute it and/or modify 
// it under the terms of the GNU General Public License as published 
// by the Free Software Foundation; version 2 of the License.
//
// This program is distributed in the hope that it will be useful, but 
// WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 
// for more details.
//
// You should have received a copy of the GNU General Public License along 
// with this program; if not, write to the Free Software Foundation, Inc., 
// 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA

using System.Collections;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Net.Sockets;

using HANDLE = System.IntPtr;
using System;
using System.IO;


namespace MySql.Data.MySqlClient
{
    /// <summary>
    /// Client side of Windows integrated ("Negotiate") authentication over the
    /// MySQL protocol. Each blob produced by InitializeSecurityContext is sent to
    /// the server as a single MySQL packet, and the server's reply packet is fed
    /// back into the next InitializeSecurityContext call until the handshake ends.
    /// </summary>
    internal class SSPI
    {
        // SSPI status codes; values as defined in the Windows SDK (winerror.h).
        const int SEC_E_OK = 0;
        const int SEC_I_CONTINUE_NEEDED = 0x90312;
        const int SEC_I_COMPLETE_NEEDED = 0x90313;
        const int SEC_I_COMPLETE_AND_CONTINUE = 0x90314;

        const int SECPKG_CRED_OUTBOUND = 2;
        const int SECURITY_NETWORK_DREP = 0;
        const int SECURITY_NATIVE_DREP = 0x10;
        const int SECPKG_CRED_INBOUND = 1;
        // Size of the output token buffer handed to InitializeSecurityContext.
        const int MAX_TOKEN_SIZE = 12288;
        const int SECPKG_ATTR_SIZES = 0;
        const int STANDARD_CONTEXT_ATTRIBUTES = 0;

        SECURITY_HANDLE outboundCredentials = new SECURITY_HANDLE(0);
        SECURITY_HANDLE clientContext = new SECURITY_HANDLE(0);
        Stream stream;
        String targetName;
        byte[] packetHeader;
        // Sequence number written into the next outgoing packet header.
        int seq = 3;


        [DllImport("secur32", CharSet = CharSet.Auto)]
        static extern int AcquireCredentialsHandle(
            string pszPrincipal,
            string pszPackage,
            int fCredentialUse,
            IntPtr PAuthenticationID,
            IntPtr pAuthData,
            int pGetKeyFn,
            IntPtr pvGetKeyArgument,
            ref SECURITY_HANDLE phCredential,
            ref SECURITY_INTEGER ptsExpiry);

        // First-call overload: no existing context and no input token.
        // pOutput is in/out: the caller supplies the buffer, SSPI fills it.
        [DllImport("secur32", CharSet = CharSet.Auto, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SECURITY_HANDLE phCredential,
            IntPtr phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            IntPtr pInput,
            int Reserved2,
            out SECURITY_HANDLE phNewContext,
            ref SecBufferDesc pOutput,
            out uint pfContextAttr,
            out SECURITY_INTEGER ptsExpiry);

        // Follow-up overload: continues an existing context with the server's token.
        [DllImport("secur32", CharSet = CharSet.Auto, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SECURITY_HANDLE phCredential,
            ref SECURITY_HANDLE phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            ref SecBufferDesc SecBufferDesc,
            int Reserved2,
            out SECURITY_HANDLE phNewContext,
            ref SecBufferDesc pOutput,
            out uint pfContextAttr,
            out SECURITY_INTEGER ptsExpiry);

        [DllImport("secur32", CharSet = CharSet.Auto, SetLastError = true)]
        static extern int CompleteAuthToken(
            ref SECURITY_HANDLE phContext,
            ref SecBufferDesc pToken);

        [DllImport("secur32.Dll", CharSet = CharSet.Auto, SetLastError = false)]
        public static extern int QueryContextAttributes(
            ref SECURITY_HANDLE phContext,
            uint ulAttribute,
            out SecPkgContext_Sizes pContextAttributes);

        [DllImport("secur32.Dll", CharSet = CharSet.Auto, SetLastError = false)]
        public static extern int FreeCredentialsHandle(ref SECURITY_HANDLE pCred);

        [DllImport("secur32.Dll", CharSet = CharSet.Auto, SetLastError = false)]
        public static extern int DeleteSecurityContext(ref SECURITY_HANDLE pCred);


        /// <summary>
        /// Creates an authenticator that exchanges packets over <paramref name="stream"/>.
        /// </summary>
        /// <param name="targetName">Currently ignored (see note in the body).</param>
        /// <param name="stream">Connection stream used to read/write MySQL packets.</param>
        /// <param name="seqNo">Sequence number for the first packet written.</param>
        public SSPI(string targetName, Stream stream, int seqNo)
        {
            // NOTE(review): the targetName argument is discarded and a null target
            // name is passed to InitializeSecurityContext. Kept to preserve current
            // behaviour; confirm whether callers expect it to be used (Kerberos
            // generally needs an SPN here).
            this.targetName = null;
            this.stream = stream;
            packetHeader = new byte[4];
            seq = seqNo;
        }


        /// <summary>
        /// Reads one MySQL packet (4-byte header: 3-byte little-endian length plus
        /// sequence byte) and returns its payload. The next outgoing sequence
        /// number becomes the received one plus one.
        /// </summary>
        /// <remarks>
        /// Only a single packet is read, on the assumption that SSPI blobs stay
        /// around the ~12K token size and never need multi-packet framing.
        /// </remarks>
        private byte[] ReadData()
        {
            byte[] buffer;
            MySqlStream.ReadFully(stream, packetHeader, 0, 4);
            int length = (int)(packetHeader[0] + (packetHeader[1] << 8) +
                (packetHeader[2] << 16));
            seq = packetHeader[3] + 1;
            buffer = new byte[length];
            MySqlStream.ReadFully(stream, buffer, 0, length);

            return buffer;
        }

        /// <summary>
        /// Writes <paramref name="buffer"/> as one MySQL packet using the current
        /// sequence number, then flushes the stream.
        /// </summary>
        private void WriteData(byte[] buffer)
        {
            int count = buffer.Length;

            packetHeader[0] = (byte)(count & 0xff);
            packetHeader[1] = (byte)((count >> 8) & 0xff);
            packetHeader[2] = (byte)((count >> 16) & 0xff);
            packetHeader[3] = (byte)(seq);
            stream.Write(packetHeader, 0, 4);
            stream.Write(buffer, 0, count);
            stream.Flush();
        }

        /// <summary>
        /// Runs the full client-side SSPI handshake with the server.
        /// </summary>
        /// <exception cref="MySqlException">
        /// Credential acquisition or a security-context step failed.
        /// </exception>
        public void AuthenticateClient()
        {
            bool continueProcessing = true;
            byte[] clientBlob = null;
            byte[] serverBlob = null;
            SECURITY_INTEGER lifetime = new SECURITY_INTEGER(0);
            int ss;

            ss = AcquireCredentialsHandle(null, "Negotiate", SECPKG_CRED_OUTBOUND,
                  IntPtr.Zero, IntPtr.Zero, 0, IntPtr.Zero, ref outboundCredentials,
                  ref lifetime);
            if (ss != SEC_E_OK)
            {
                throw new MySqlException(
                    "AcquireCredentialsHandle failed with errorcode " + ss);
            }
            try
            {
                while (continueProcessing)
                {
                    InitializeClient(out clientBlob, serverBlob,
                        out continueProcessing);
                    if (clientBlob != null && clientBlob.Length > 0)
                    {
                        WriteData(clientBlob);
                        // Only wait for a server token if SSPI asked for another leg.
                        if (continueProcessing)
                        {
                            serverBlob = ReadData();
                        }
                    }
                }
            }
            finally
            {
                // Release native handles regardless of how the handshake ended.
                FreeCredentialsHandle(ref outboundCredentials);
                DeleteSecurityContext(ref clientContext);
            }
        }


        /// <summary>
        /// Performs one InitializeSecurityContext step.
        /// </summary>
        /// <param name="clientBlob">Token to send to the server, or null if none was produced.</param>
        /// <param name="serverBlob">Server token from the previous leg, or null on the first call.</param>
        /// <param name="continueProcessing">True when SSPI expects another round trip.</param>
        /// <exception cref="MySqlException">An SSPI call returned a failure status.</exception>
        void InitializeClient(out byte[] clientBlob, byte[] serverBlob,
            out bool continueProcessing)
        {
            clientBlob = null;
            continueProcessing = true;
            SecBufferDesc clientBufferDesc = new SecBufferDesc(MAX_TOKEN_SIZE);
            SECURITY_INTEGER lifetime = new SECURITY_INTEGER(0);
            int ss = -1;
            try
            {
                uint contextAttributes = 0;

                if (serverBlob == null)
                {
                    ss = InitializeSecurityContext(
                        ref outboundCredentials,
                        IntPtr.Zero,
                        targetName,
                        STANDARD_CONTEXT_ATTRIBUTES,
                        0,
                        SECURITY_NETWORK_DREP,
                        IntPtr.Zero, /* always zero first time around */
                        0,
                        out clientContext,
                        ref clientBufferDesc,
                        out contextAttributes,
                        out lifetime);
                }
                else
                {
                    SecBufferDesc serverBufferDesc = new SecBufferDesc(serverBlob);

                    try
                    {
                        // The same context field is both input and output handle.
                        ss = InitializeSecurityContext(ref outboundCredentials,
                            ref clientContext,
                            targetName,
                            STANDARD_CONTEXT_ATTRIBUTES,
                            0,
                            SECURITY_NETWORK_DREP,
                            ref serverBufferDesc,
                            0,
                            out clientContext,
                            ref clientBufferDesc,
                            out contextAttributes,
                            out lifetime);
                    }
                    finally
                    {
                        serverBufferDesc.Dispose();
                    }
                }

                if (ss != SEC_E_OK &&
                    ss != SEC_I_CONTINUE_NEEDED &&
                    ss != SEC_I_COMPLETE_NEEDED &&
                    ss != SEC_I_COMPLETE_AND_CONTINUE)
                {
                    throw new MySqlException(
                        "InitializeSecurityContext() failed with errorcode " + ss);
                }

                // Some packages require the output token to be finalized before sending.
                if (ss == SEC_I_COMPLETE_NEEDED || ss == SEC_I_COMPLETE_AND_CONTINUE)
                {
                    int completeStatus = CompleteAuthToken(ref clientContext,
                        ref clientBufferDesc);
                    if (completeStatus != SEC_E_OK)
                    {
                        throw new MySqlException(
                            "CompleteAuthToken() failed with errorcode " + completeStatus);
                    }
                }

                clientBlob = clientBufferDesc.GetSecBufferByteArray();
            }
            finally
            {
                clientBufferDesc.Dispose();
            }
            continueProcessing = (ss != SEC_E_OK && ss != SEC_I_COMPLETE_NEEDED);
        }
    }


    /// <summary>
    /// Managed mirror of the native SecBufferDesc: a version tag, a buffer count,
    /// and a pointer to exactly one SecBuffer copied into unmanaged memory by the
    /// constructors. Dispose releases that SecBuffer's payload and the copy itself.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    struct SecBufferDesc : IDisposable
    {
        public int ulVersion;
        public int cBuffers;
        public IntPtr pBuffers; // unmanaged copy of a single SecBuffer

        /// <summary>Describes one empty token buffer of <paramref name="bufferSize"/> bytes.</summary>
        public SecBufferDesc(int bufferSize)
        {
            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
            cBuffers = 1;
            pBuffers = PlaceOnHeap(new SecBuffer(bufferSize));
        }

        /// <summary>Describes one token buffer holding a copy of <paramref name="secBufferBytes"/>.</summary>
        public SecBufferDesc(byte[] secBufferBytes)
        {
            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
            cBuffers = 1;
            pBuffers = PlaceOnHeap(new SecBuffer(secBufferBytes));
        }

        // Copies a SecBuffer into a fresh HGlobal block and returns its address.
        private static IntPtr PlaceOnHeap(SecBuffer buffer)
        {
            IntPtr block = Marshal.AllocHGlobal(Marshal.SizeOf(buffer));
            Marshal.StructureToPtr(buffer, block, false);
            return block;
        }

        /// <summary>Frees the payload and the SecBuffer copy; safe to call repeatedly.</summary>
        public void Dispose()
        {
            if (pBuffers == IntPtr.Zero)
            {
                return;
            }

            Debug.Assert(cBuffers == 1);
            SecBuffer inner =
                (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
            inner.Dispose();
            Marshal.FreeHGlobal(pBuffers);
            pBuffers = IntPtr.Zero;
        }

        /// <summary>
        /// Copies the described buffer's current contents (cbBuffer bytes) into a
        /// managed array; returns null when cbBuffer is not positive.
        /// </summary>
        /// <exception cref="InvalidOperationException">Already disposed.</exception>
        public byte[] GetSecBufferByteArray()
        {
            if (pBuffers == IntPtr.Zero)
            {
                throw new InvalidOperationException("Object has already been disposed!!!");
            }

            Debug.Assert(cBuffers == 1);
            SecBuffer inner =
                (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
            if (inner.cbBuffer <= 0)
            {
                return null;
            }

            byte[] bytes = new byte[inner.cbBuffer];
            Marshal.Copy(inner.pvBuffer, bytes, 0, inner.cbBuffer);
            return bytes;
        }
    }

    /// <summary>
    /// SSPI constants used when building SecBuffer / SecBufferDesc values.
    /// </summary>
    /// <remarks>
    /// SECBUFFER_VERSION is not a buffer type: it is the value written to
    /// SecBufferDesc.ulVersion. It shares the value 0 with SECBUFFER_EMPTY, so
    /// mapping 0 back to a name (e.g. via ToString()) is ambiguous.
    /// </remarks>
    public enum SecBufferType
    {
        SECBUFFER_VERSION = 0,   // SecBufferDesc.ulVersion value, not a buffer type
        SECBUFFER_EMPTY = 0,
        SECBUFFER_DATA = 1,
        SECBUFFER_TOKEN = 2      // default type used by the SecBuffer constructors
    }

    /// <summary>
    /// Managed mirror of the native SecHandle (CredHandle / CtxtHandle): two
    /// ULONG_PTR values, declared as IntPtr so each field is pointer-sized on
    /// 32- and 64-bit processes alike. Not referenced elsewhere in this file.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct SecHandle //=PCtxtHandle
    {
        // Order matters for the sequential layout: dwLower first, then dwUpper.
        private IntPtr dwLower, dwUpper;
    }

    /// <summary>
    /// Managed mirror of the native SecBuffer: byte count, buffer type, and a
    /// pointer to an HGlobal allocation owned by this value. Dispose frees it.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct SecBuffer : IDisposable
    {
        public int cbBuffer;
        public int BufferType;
        public IntPtr pvBuffer;


        /// <summary>Allocates an uninitialized token buffer of <paramref name="bufferSize"/> bytes.</summary>
        public SecBuffer(int bufferSize)
        {
            cbBuffer = bufferSize;
            BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
            pvBuffer = Marshal.AllocHGlobal(bufferSize);
        }

        /// <summary>Allocates a token buffer initialized with a copy of <paramref name="secBufferBytes"/>.</summary>
        public SecBuffer(byte[] secBufferBytes)
            : this(secBufferBytes, SecBufferType.SECBUFFER_TOKEN)
        {
        }

        /// <summary>Allocates a buffer of the given type initialized with a copy of <paramref name="secBufferBytes"/>.</summary>
        public SecBuffer(byte[] secBufferBytes, SecBufferType bufferType)
        {
            cbBuffer = secBufferBytes.Length;
            BufferType = (int)bufferType;
            pvBuffer = Marshal.AllocHGlobal(cbBuffer);
            Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer);
        }

        /// <summary>Frees the payload allocation; a second call is a no-op on the same copy.</summary>
        public void Dispose()
        {
            if (pvBuffer == IntPtr.Zero)
            {
                return;
            }

            Marshal.FreeHGlobal(pvBuffer);
            pvBuffer = IntPtr.Zero;
        }
    }
    /// <summary>
    /// Managed mirror of the native 64-bit TimeStamp (LARGE_INTEGER layout:
    /// unsigned low half, signed high half). The int constructor argument is
    /// ignored; it only exists to produce a zeroed value.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct SECURITY_INTEGER
    {
        public uint LowPart;
        public int HighPart;

        public SECURITY_INTEGER(int dummy)
        {
            HighPart = 0;
            LowPart = 0u;
        }
    }

    /// <summary>
    /// Managed mirror of the native CredHandle / CtxtHandle (SecHandle), which is
    /// two ULONG_PTR values. The fields are IntPtr so the struct has the native
    /// size on 64-bit processes. With 32-bit fields, SSPI would write 16 bytes
    /// into an 8-byte struct there. The int constructor argument is ignored and
    /// only produces a zeroed handle.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct SECURITY_HANDLE
    {
        public IntPtr LowPart;
        public IntPtr HighPart;

        public SECURITY_HANDLE(int dummy)
        {
            LowPart = HighPart = IntPtr.Zero;
        }
    }

    /// <summary>
    /// Managed mirror of the native SecPkgContext_Sizes: four unsigned 32-bit
    /// counts, in the declared order.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct SecPkgContext_Sizes
    {
        // Declaration order defines the sequential layout; do not reorder.
        public uint cbMaxToken, cbMaxSignature, cbBlockSize, cbSecurityTrailer;
    }

}

